"""
Visualization module for cognitive maps and evaluation results.

This module provides functions to:
1. Visualize cognitive maps
2. Plot evaluation results
3. Generate comparison visualizations
"""

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib.lines import Line2D
import networkx as nx
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Any, Union

def plot_cognitive_map(cogmap: Dict, title: str = "Cognitive Map", figsize: Tuple[int, int] = (10, 10)) -> Figure:
    """
    Visualize a cognitive map as a 2D plot in a new figure.

    The map is parsed and drawn by ``plot_map_on_axis`` (which accepts both the
    ``{"objects": [...], "views": [...]}`` layout and the flat
    ``{name: {"position": ..., "facing": ...}}`` layout); this function adds a
    legend distinguishing objects (blue) from views (green).

    Args:
        cogmap: The cognitive map dictionary
        title: The plot title
        figsize: Figure size as (width, height)

    Returns:
        The matplotlib figure
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Delegate parsing and drawing to the shared helper so a single-map plot and the
    # side-by-side comparison can never render the same map differently.
    plot_map_on_axis(ax, cogmap, title)

    # The legend is static: both entries are shown whether or not the map has views.
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markersize=10, label='Object'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='green', markersize=10, label='View')
    ]
    ax.legend(handles=legend_elements)

    return fig

def draw_facing_arrow(ax: Axes, pos: List[float], facing: str, arrow_length: float = 0.5) -> None:
    """
    Draw a black arrow starting at ``pos`` that indicates a facing direction.

    Args:
        ax: Matplotlib axes to draw on
        pos: The position coordinates [x, y]; only ``pos[0]`` and ``pos[1]`` are used
        facing: One of 'up', 'down', 'left', 'right', 'inner', 'outer'
            (case-insensitive, surrounding whitespace ignored). Any other value,
            including non-string values such as lists, draws nothing.
        arrow_length: Scale factor applied to the direction vector
    """
    # 'inner'/'outer' are fixed, shorter diagonals toward +x+y / -x-y; they are not
    # computed relative to any map centre.
    direction_map = {
        "up": (0, 1),
        "down": (0, -1),
        "left": (-1, 0),
        "right": (1, 0),
        "inner": (0.2, 0.2),
        "outer": (-0.2, -0.2)
    }

    # Maps may be model-generated, so ``facing`` is not guaranteed to be a string;
    # an unhashable value (e.g. a list) would make the dict lookup raise TypeError.
    if not isinstance(facing, str):
        return

    direction = direction_map.get(facing.strip().lower())
    if direction is None:
        return

    dx, dy = direction
    ax.arrow(
        pos[0], pos[1], dx * arrow_length, dy * arrow_length,
        head_width=0.1, head_length=0.1, fc='black', ec='black'
    )

def plot_cognitive_map_comparison(generated_map: Dict, ground_truth_map: Dict, 
                              similarity_metrics: Optional[Dict] = None,
                              figsize: Tuple[int, int] = (15, 7)) -> Figure:
    """
    Plot a side-by-side comparison of generated and ground truth cognitive maps.

    When ``similarity_metrics`` is given, a text box is added below the plots:
    first ``Isomorphic: <value>`` (if the ``rotation_invariant_isomorphic`` key is
    present), then one ``Title Cased Key: 0.1234`` line per numeric (non-boolean)
    metric, excluding ``valid_graph``.

    Args:
        generated_map: The generated cognitive map
        ground_truth_map: The ground truth cognitive map
        similarity_metrics: Optional dictionary of similarity metrics to display
        figsize: Figure size as (width, height)

    Returns:
        The matplotlib figure
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Plot the two maps
    plot_map_on_axis(ax1, generated_map, "Generated Cognitive Map")
    plot_map_on_axis(ax2, ground_truth_map, "Ground Truth Cognitive Map")

    lines: List[str] = []
    if similarity_metrics:
        if "rotation_invariant_isomorphic" in similarity_metrics:
            isomorphic = similarity_metrics["rotation_invariant_isomorphic"]
            lines.append(f"Isomorphic: {isomorphic}")

        for key, value in similarity_metrics.items():
            # bool is a subclass of int: without this exclusion boolean flags were
            # printed as 1.0000/0.0000 (and the isomorphic flag appeared twice).
            if key == "valid_graph" or isinstance(value, (bool, np.bool_)):
                continue
            # numpy integer scalars are not int subclasses, so list them explicitly.
            if isinstance(value, (int, float, np.integer, np.floating)):
                lines.append(f"{str(key).replace('_', ' ').title()}: {value:.4f}")

    if lines:
        fig.text(0.5, 0.01, "\n".join(lines), ha='center', va='bottom',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        # Reserve space at the bottom for the metrics box.
        fig.tight_layout(rect=(0, 0.1, 1, 0.95))
    else:
        fig.tight_layout()
    return fig

def plot_map_on_axis(ax: Axes, cogmap: Dict, title: str) -> None:
    """
    Plot a cognitive map on a specific axis.

    Two map layouts are accepted:

    * complex: ``cogmap["objects"]`` is a list of dicts with ``"name"`` and
      ``"position"`` (and optionally ``"facing"``), optionally accompanied by a
      ``cogmap["views"]`` list of the same shape. Objects are drawn blue, views
      green; a view sharing a name with an object replaces it.
    * simple: any other dict, read as ``{name: {"position": ..., "facing": ...}}``;
      values that are not dicts with a ``"position"`` key are ignored.

    Entries that are not dicts, or whose position cannot be read as two numbers
    (``position[0]``, ``position[1]``), are skipped instead of aborting the plot.

    Args:
        ax: The matplotlib axis to plot on
        cogmap: The cognitive map dictionary
        title: The plot title
    """
    colors = {"object": "blue", "view": "green"}

    for name, entry in _extract_plot_entries(cogmap).items():
        xy = _coerce_xy(entry["position"])
        if xy is None:
            continue  # malformed position: skip this entry, keep the rest
        x, y = xy

        ax.scatter(x, y, color=colors.get(entry["type"], "blue"), s=100)
        ax.annotate(name, (x, y), xytext=(5, 5), textcoords="offset points")

        facing = entry["facing"]
        # Only string directions can be drawn; other types (e.g. lists) are ignored.
        if isinstance(facing, str) and facing:
            draw_facing_arrow(ax, [x, y], facing)

    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.grid(True)
    ax.axis("equal")  # Equal aspect ratio


def _extract_plot_entries(cogmap: Dict) -> Dict[Any, Dict[str, Any]]:
    """Normalize either map layout to ``{name: {"position", "facing", "type"}}``."""
    entries: Dict[Any, Dict[str, Any]] = {}

    if isinstance(cogmap.get("objects"), list):
        # Objects first, then views, so a view overrides a same-named object.
        for key, entry_type in (("objects", "object"), ("views", "view")):
            items = cogmap.get(key) or []
            if not isinstance(items, (list, tuple)):
                continue
            for item in items:
                if isinstance(item, dict) and "name" in item and "position" in item:
                    entries[item["name"]] = {
                        "position": item["position"],
                        "facing": item.get("facing", None),
                        "type": entry_type,
                    }
    else:
        for name, data in cogmap.items():
            if isinstance(data, dict) and "position" in data:
                entries[name] = {
                    "position": data["position"],
                    "facing": data.get("facing", None),
                    "type": "object",
                }

    return entries


def _coerce_xy(position: Any) -> Optional[Tuple[float, float]]:
    """Return ``(x, y)`` from the first two items of ``position``, or None if unreadable."""
    try:
        return float(position[0]), float(position[1])
    except (TypeError, ValueError, IndexError, KeyError):
        return None

def plot_results_by_setting(results: Dict, figsize: Tuple[int, int] = (12, 10),
                            settings: Optional[List[str]] = None) -> Figure:
    """
    Plot evaluation results broken down by setting (one 2x2 grid of bar charts).

    For each setting, ``results['settings'][setting]`` is read for ``total``,
    ``gen_cogmap_accuracy`` and ``cogmap_similarity`` (``valid_percent`` on a
    0-100 scale, ``avg_relative_position_accuracy``, ``avg_facing_similarity``).
    A setting that is missing or has ``total == 0`` is plotted as zeros.

    Args:
        results: The evaluation results dictionary
        figsize: Figure size as (width, height)
        settings: Settings to plot, in bar order. Defaults to
            ['around', 'rotation', 'translation', 'among'].

    Returns:
        The matplotlib figure

    Raises:
        KeyError: if ``results`` has no ``'settings'`` entry, or a setting with
            ``total > 0`` lacks one of the fields listed above.
    """
    if settings is None:
        settings = ['around', 'rotation', 'translation', 'among']
    metrics = ['gen_cogmap_accuracy', 'valid_percent', 'avg_relative_position_accuracy', 'avg_facing_similarity']
    metric_labels = ['Answer Accuracy', 'Valid Maps (%)', 'Position Accuracy', 'Facing Similarity']

    # Gather data before creating the figure so malformed input does not leak an
    # open figure.
    all_settings = results['settings']
    metric_data: Dict[str, List[float]] = {metric: [] for metric in metrics}
    for setting in settings:
        stats = all_settings.get(setting) or {}
        if stats.get('total', 0) > 0:
            similarity = stats['cogmap_similarity']
            metric_data['gen_cogmap_accuracy'].append(stats['gen_cogmap_accuracy'])
            metric_data['valid_percent'].append(similarity['valid_percent'] / 100.0)  # Convert to fraction
            metric_data['avg_relative_position_accuracy'].append(similarity['avg_relative_position_accuracy'])
            metric_data['avg_facing_similarity'].append(similarity['avg_facing_similarity'])
        else:
            # Missing settings and settings without data are shown as zero.
            for metric in metrics:
                metric_data[metric].append(0)

    fig, axs = plt.subplots(2, 2, figsize=figsize)
    axs = axs.flatten()

    for ax, metric, label in zip(axs, metrics, metric_labels):
        values = metric_data[metric]
        ax.bar(settings, values)
        ax.set_title(label)
        # Metrics are fractions in [0, 1]; the extra headroom keeps value labels of
        # full-height bars inside the axes instead of colliding with the title.
        ax.set_ylim(0, 1.1)
        for j, v in enumerate(values):
            ax.text(j, v + 0.02, f"{v:.2f}", ha='center')

    fig.tight_layout()
    return fig

def plot_model_comparison(results_list: List[Dict], metric: str = "valid_percent", figsize: Tuple[int, int] = (12, 8)) -> Figure:
    """
    Plot comparison of multiple models based on a specific metric.

    Each bar is labelled ``"<model>-<version>"``. Metric values are taken as:

    * ``"valid_percent"``: ``result['cogmap_similarity']['valid_percent'] / 100``
    * ``"isomorphic_rate"``: ``result['cogmap_similarity']['isomorphic_count'] / result['total']``
      (0 when ``total`` is 0)
    * anything else: ``result[metric]``, or 0 if the key is absent

    Args:
        results_list: List of result dictionaries from different models
        metric: The metric to compare ('valid_percent', 'gen_cogmap_accuracy', etc.)
        figsize: Figure size as (width, height)

    Returns:
        The matplotlib figure
    """
    model_names = [f"{result['model']}-{result['version']}" for result in results_list]

    metric_values = []
    for result in results_list:
        if metric == "valid_percent":
            value = result['cogmap_similarity']['valid_percent'] / 100.0  # Convert to fraction
        elif metric == "isomorphic_rate":
            total = result['total']
            value = result['cogmap_similarity']['isomorphic_count'] / total if total > 0 else 0
        else:
            value = result.get(metric, 0)  # Default if metric not found
        metric_values.append(value)

    fig, ax = plt.subplots(figsize=figsize)

    # Plot at integer positions instead of passing names as categorical x values:
    # matplotlib maps identical category strings to the same tick, so two runs of
    # the same model/version would otherwise be drawn as overlapping bars.
    positions = np.arange(len(model_names))
    bars = ax.bar(positions, metric_values)
    ax.set_xticks(positions)
    rotate = len(model_names) > 4  # rotate labels when there are many models
    ax.set_xticklabels(model_names,
                       rotation=45 if rotate else 0,
                       ha='right' if rotate else 'center')

    # Add value labels above each bar
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height + 0.02,
                f"{height:.2f}", ha='center', va='bottom')

    metric_label = metric.replace('_', ' ').title()
    ax.set_title(f"Model Comparison: {metric_label}")
    ax.set_xlabel("Model")
    ax.set_ylabel(metric_label)
    # default=0 keeps an empty results_list from raising ValueError in max().
    ax.set_ylim(0, max(metric_values, default=0) + 0.1)

    fig.tight_layout()
    return fig

def save_plot(fig: Figure, filename: str, dpi: int = 300) -> None:
    """
    Save a matplotlib figure to a file, then close it.

    Parent directories of ``filename`` are created if needed. The figure is closed
    even if saving fails, so it does not stay open in pyplot's figure registry.

    Args:
        fig: The matplotlib figure
        filename: Output filename (a bare name saves into the current directory)
        dpi: Resolution in dots per inch
    """
    import os

    try:
        directory = os.path.dirname(filename)
        # dirname is "" for a bare filename such as "plot.png", and os.makedirs("")
        # raises FileNotFoundError, so only create a directory when one is named.
        if directory:
            os.makedirs(directory, exist_ok=True)

        fig.savefig(filename, dpi=dpi, bbox_inches='tight')
        print(f"Plot saved to {filename}")
    finally:
        # Close the figure to free memory
        plt.close(fig)